{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2KROuTZVuhrp"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "6aEVQQ403kzs"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P5VpOpSivqgv"
      },
      "source": [
        "# `fit()`の処理をカスタマイズする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vApNeEfvLLc4"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\"> GitHub でソースを表示する</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/keras/customizing_what_happens_in_fit.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TihRiHIKeeIz"
      },
      "source": [
        "## はじめに\n",
        "\n",
        "教師あり学習をする場合、`fit()`を使用すると全てがスムーズに動作します。\n",
        "\n",
        "独自のトレーニングループを新規で書く必要がある場合には、`GradientTape`を使用すると、細部までコントロールすることができます。\n",
        "\n",
        "しかし、カスタムトレーニングアルゴリズムが必要で、なおかつコールバック、組み込み分散サポート、ステップ結合など、`fit()`の便利な機能を利用したい場合には、どうしたらよいでしょうか？\n",
        "\n",
        "Keras の核となる原則は、**複雑性のプログレッシブディスクロージャ―**です。常に低レベルのワークフローに段階的に入ることが可能です。高レベルの機能性がユースケースと完全に一致しない場合でも、急激に性能が落ちるようなことはありません。相応の高レベルの利便性を維持しながら細部をよりコントロールできるはずです。\n",
        "\n",
        "`fit()`の動作をカスタマイズする必要がある場合は、**`Model`クラスのトレーニングステップ関数をオーバーライド**する必要があります。これはデータのバッチごとに`fit()`に呼び出される関数です。これによって、通常通りの`fit()`の呼び出しが可能になり、独自の学習アルゴリズムが実行されます。\n",
        "\n",
        "このパターンは Functional API を使用したモデル構築を妨げるものではないことに注意してください。これは、`Sequential`モデル、Functional API モデル、サブクラス化されたモデルのいずれを構築する場合にも適用可能です。\n",
        "\n",
        "では、その仕組みを見ていきましょう。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XFRryV6yxq2Z"
      },
      "source": [
        "## セットアップ\n",
        "\n",
        "TensorFlow 2.2 以降が必要です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BxGJZEXaWrLM"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1yZO4J3zyOfz"
      },
      "source": [
        "## 最初の簡単な例\n",
        "\n",
        "簡単な例から始めてみましょう。\n",
        "\n",
        "- `keras.Model`をサブクラス化する新しいクラスを作成します。\n",
        "- `train_step(self, data)`メソッドをオーバーライドするだけです。\n",
        "- メトリック名（損失を含む）をマッピングするディクショナリを現在の値に返します。\n",
        "\n",
        "トレーニングデータとして fit() に渡されるのが、入力引数`data`です。\n",
        "\n",
        "- `fit(x, y, ...)`を呼び出して Numpy 配列を渡すと、`data`はタプル`(x, y)`になります。\n",
        "- `fit(dataset, ...)`を呼び出して`tf.data.Dataset`を渡すと、`data`は各バッチで`dataset`により生成されたものになります。\n",
        "\n",
        "`train_step`メソッドの本体には、お馴染みの定期的なトレーニング更新を実装します。重要なのは、**損失を`self.compiled_loss`で計算する**ため、`compile()`に渡された損失関数をラップしていることです。\n",
        "\n",
        "同様に、`self.compiled_metrics.update_state(y, y_pred)`を呼び出して`compile()`で渡されたメトリクスの状態を更新し、最後に`self.metrics`の結果を照会して現在の値を取得します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sg0aNp6yuNUs"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        # Unpack the data. Its structure depends on your model and\n",
        "        # on what you pass to `fit()`.\n",
        "        x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute the loss value\n",
        "            # (the loss function is configured in `compile()`)\n",
        "            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "        # Update metrics (includes the metric that tracks the loss)\n",
        "        self.compiled_metrics.update_state(y, y_pred)\n",
        "        # Return a dict mapping metric names to current value\n",
        "        return {m.name: m.result() for m in self.metrics}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YEdOFRbXmA4d"
      },
      "source": [
        "これを試してみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1wDUe4ReTaVi"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Construct and compile an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# Just use `fit` as usual\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.fit(x, y, epochs=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tQSwBvcGIeZk"
      },
      "source": [
        "## 低レベルにする\n",
        "\n",
        "当然ながら、`compile()`で損失関数を渡すことを省略し、その代わりに`train_step`で全てを*手動で*実行することは可能です。これはメトリクスの場合でも同様です。オプティマイザの構成に`compile()`のみを使用した、低レベルの例を次に示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9UnwB6gdESVw"
      },
      "outputs": [],
      "source": [
        "mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n",
        "loss_tracker = keras.metrics.Mean(name=\"loss\")\n",
        "\n",
        "\n",
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute our own loss\n",
        "            loss = keras.losses.mean_squared_error(y, y_pred)\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "\n",
        "        # Compute our own metrics\n",
        "        loss_tracker.update_state(loss)\n",
        "        mae_metric.update_state(y, y_pred)\n",
        "        return {\"loss\": loss_tracker.result(), \"mae\": mae_metric.result()}\n",
        "\n",
        "\n",
        "# Construct an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "\n",
        "# We don't passs a loss or metrics here.\n",
        "model.compile(optimizer=\"adam\")\n",
        "\n",
        "# Just use `fit` as usual -- you can use callbacks, etc.\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.fit(x, y, epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WN0qnQacU9u2"
      },
      "source": [
        "このセットアップでは、各エポックの後に、またはトレーニングと評価の間に、メトリクス上で手動で`reset_states()`を呼び出す必要があることに注意してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fnMF4QYQFNj1"
      },
      "source": [
        "## `sample_weight`と`class_weight`をサポートする\n",
        "\n",
        "最初の基本的な例には、サンプルの重み付けに関する言及が一切なかったことにお気づきでしょうか。`fit()`引数の`sample_weight`と`class_weight`をサポートする場合には、以下のようにします。\n",
        "\n",
        "- `data`引数から`sample_weight`をアンパックする\n",
        "- それを`compiled_loss`と`compiled_metrics`に渡す（もちろん、 損失とメトリクスが`compile()`に依存しない場合は手動での適用が可能）\n",
        "- これだけで完了です。これがリストです。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tE4yrX22rlL4"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def train_step(self, data):\n",
        "        # Unpack the data. Its structure depends on your model and\n",
        "        # on what you pass to `fit()`.\n",
        "        if len(data) == 3:\n",
        "            x, y, sample_weight = data\n",
        "        else:\n",
        "            x, y = data\n",
        "\n",
        "        with tf.GradientTape() as tape:\n",
        "            y_pred = self(x, training=True)  # Forward pass\n",
        "            # Compute the loss value.\n",
        "            # The loss function is configured in `compile()`.\n",
        "            loss = self.compiled_loss(\n",
        "                y,\n",
        "                y_pred,\n",
        "                sample_weight=sample_weight,\n",
        "                regularization_losses=self.losses,\n",
        "            )\n",
        "\n",
        "        # Compute gradients\n",
        "        trainable_vars = self.trainable_variables\n",
        "        gradients = tape.gradient(loss, trainable_vars)\n",
        "\n",
        "        # Update weights\n",
        "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
        "\n",
        "        # Update the metrics.\n",
        "        # Metrics are configured in `compile()`.\n",
        "        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)\n",
        "\n",
        "        # Return a dict mapping metric names to current value.\n",
        "        # Note that it will include the loss (tracked in self.metrics).\n",
        "        return {m.name: m.result() for m in self.metrics}\n",
        "\n",
        "\n",
        "# Construct and compile an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# You can now use sample_weight argument\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "sw = np.random.random((1000, 1))\n",
        "model.fit(x, y, sample_weight=sw, epochs=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j0uOhTfBjhYX"
      },
      "source": [
        "## 独自の評価ステップを提供する\n",
        "\n",
        "`model.evaluate()`の呼び出しに同じことをする場合はどうしたらよいでしょう？その場合は、全く同じ方法で`test_step`をオーバーライドします。これは次のようになります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vaogkBppfg2t"
      },
      "outputs": [],
      "source": [
        "class CustomModel(keras.Model):\n",
        "    def test_step(self, data):\n",
        "        # Unpack the data\n",
        "        x, y = data\n",
        "        # Compute predictions\n",
        "        y_pred = self(x, training=False)\n",
        "        # Updates the metrics tracking the loss\n",
        "        self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n",
        "        # Update the metrics.\n",
        "        self.compiled_metrics.update_state(y, y_pred)\n",
        "        # Return a dict mapping metric names to current value.\n",
        "        # Note that it will include the loss (tracked in self.metrics).\n",
        "        return {m.name: m.result() for m in self.metrics}\n",
        "\n",
        "\n",
        "# Construct an instance of CustomModel\n",
        "inputs = keras.Input(shape=(32,))\n",
        "outputs = keras.layers.Dense(1)(inputs)\n",
        "model = CustomModel(inputs, outputs)\n",
        "model.compile(loss=\"mse\", metrics=[\"mae\"])\n",
        "\n",
        "# Evaluate with our custom test_step\n",
        "x = np.random.random((1000, 32))\n",
        "y = np.random.random((1000, 1))\n",
        "model.evaluate(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xiE4ZsCtjI9B"
      },
      "source": [
        "## まとめ: エンドツーエンド GAN の例\n",
        "\n",
        "ここで学んだ全てを活用した、エンドツーエンドの例を見てみましょう。\n",
        "\n",
        "以下を検討してみましょう。\n",
        "\n",
        "- 28x28x1 の画像を生成するジェネレーターネットワーク。\n",
        "- 28x28x1 の画像を 2 つのクラス（「偽物」と「本物」）に分類するディスクリミネーターネットワーク。\n",
        "- それぞれに 1 つのオプティマイザ。\n",
        "- ディスクリミネーターをトレーニングする損失関数。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jyyxuepxgMuF"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "# Create the discriminator\n",
        "discriminator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(28, 28, 1)),\n",
        "        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.GlobalMaxPooling2D(),\n",
        "        layers.Dense(1),\n",
        "    ],\n",
        "    name=\"discriminator\",\n",
        ")\n",
        "\n",
        "# Create the generator\n",
        "latent_dim = 128\n",
        "generator = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(latent_dim,)),\n",
        "        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n",
        "        layers.Dense(7 * 7 * 128),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Reshape((7, 7, 128)),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
        "        layers.LeakyReLU(alpha=0.2),\n",
        "        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n",
        "    ],\n",
        "    name=\"generator\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cxFFOFm7xbCM"
      },
      "source": [
        "ここにフィーチャーコンプリートの GAN クラスがあります。`compile()`をオーバーライドして独自のシグネチャを使用することにより、GAN アルゴリズム全体を`train_step`の 17 行で実装しています。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wvS2v5pvGM7h"
      },
      "outputs": [],
      "source": [
        "class GAN(keras.Model):\n",
        "    def __init__(self, discriminator, generator, latent_dim):\n",
        "        super(GAN, self).__init__()\n",
        "        self.discriminator = discriminator\n",
        "        self.generator = generator\n",
        "        self.latent_dim = latent_dim\n",
        "\n",
        "    def compile(self, d_optimizer, g_optimizer, loss_fn):\n",
        "        super(GAN, self).compile()\n",
        "        self.d_optimizer = d_optimizer\n",
        "        self.g_optimizer = g_optimizer\n",
        "        self.loss_fn = loss_fn\n",
        "\n",
        "    def train_step(self, real_images):\n",
        "        if isinstance(real_images, tuple):\n",
        "            real_images = real_images[0]\n",
        "        # Sample random points in the latent space\n",
        "        batch_size = tf.shape(real_images)[0]\n",
        "        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
        "\n",
        "        # Decode them to fake images\n",
        "        generated_images = self.generator(random_latent_vectors)\n",
        "\n",
        "        # Combine them with real images\n",
        "        combined_images = tf.concat([generated_images, real_images], axis=0)\n",
        "\n",
        "        # Assemble labels discriminating real from fake images\n",
        "        labels = tf.concat(\n",
        "            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n",
        "        )\n",
        "        # Add random noise to the labels - important trick!\n",
        "        labels += 0.05 * tf.random.uniform(tf.shape(labels))\n",
        "\n",
        "        # Train the discriminator\n",
        "        with tf.GradientTape() as tape:\n",
        "            predictions = self.discriminator(combined_images)\n",
        "            d_loss = self.loss_fn(labels, predictions)\n",
        "        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n",
        "        self.d_optimizer.apply_gradients(\n",
        "            zip(grads, self.discriminator.trainable_weights)\n",
        "        )\n",
        "\n",
        "        # Sample random points in the latent space\n",
        "        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))\n",
        "\n",
        "        # Assemble labels that say \"all real images\"\n",
        "        misleading_labels = tf.zeros((batch_size, 1))\n",
        "\n",
        "        # Train the generator (note that we should *not* update the weights\n",
        "        # of the discriminator)!\n",
        "        with tf.GradientTape() as tape:\n",
        "            predictions = self.discriminator(self.generator(random_latent_vectors))\n",
        "            g_loss = self.loss_fn(misleading_labels, predictions)\n",
        "        grads = tape.gradient(g_loss, self.generator.trainable_weights)\n",
        "        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))\n",
        "        return {\"d_loss\": d_loss, \"g_loss\": g_loss}\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FGTUQysnjlsX"
      },
      "source": [
        "テストドライブしてみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mPJp4mErKaq1"
      },
      "outputs": [],
      "source": [
        "# Prepare the dataset. We use both the training &amp; test MNIST digits.\n",
        "batch_size = 64\n",
        "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
        "all_digits = np.concatenate([x_train, x_test])\n",
        "all_digits = all_digits.astype(\"float32\") / 255.0\n",
        "all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
        "dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
        "dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
        "\n",
        "gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)\n",
        "gan.compile(\n",
        "    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n",
        "    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n",
        "    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        ")\n",
        "\n",
        "# To limit execution time, we only train on 100 batches. You can train on\n",
        "# the entire dataset. You will need about 20 epochs to get nice results.\n",
        "gan.fit(dataset.take(100), epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZTGksVTeL0qk"
      },
      "source": [
        "ディープラーニングの背景にある考え方は単純です。実装もそうあるべきだと思います。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "customizing_what_happens_in_fit.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
